{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-08-12T09:37:03.268465Z",
     "start_time": "2024-08-12T09:37:00.169686Z"
    }
   },
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from sklearn.metrics import accuracy_score"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-12T09:37:03.304383Z",
     "start_time": "2024-08-12T09:37:03.269474Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Record the interpreter and library versions used for this run.\n",
    "print(sys.version_info)\n",
    "for module in (mpl, np, pd, sklearn, torch):\n",
    "    print(module.__name__, module.__version__)\n",
    "\n",
    "# Use the first CUDA device when one is available, otherwise the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "print(device)"
   ],
   "id": "c55aaf38afc149d6",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sys.version_info(major=3, minor=12, micro=3, releaselevel='final', serial=0)\n",
      "matplotlib 3.9.1\n",
      "numpy 1.26.4\n",
      "pandas 2.2.2\n",
      "sklearn 1.5.1\n",
      "torch 2.4.0+cu121\n",
      "cuda:0\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-12T09:37:04.200681Z",
     "start_time": "2024-08-12T09:37:03.304887Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from torchvision import datasets\n",
    "from torchvision.transforms import ToTensor"
   ],
   "id": "7f5f757fd3b2418d",
   "outputs": [],
   "execution_count": 3
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-12T09:37:04.233910Z",
     "start_time": "2024-08-12T09:37:04.200681Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Fashion-MNIST training split; ToTensor() is applied to every image.\n",
    "train_ds = datasets.FashionMNIST(root=\"data\", train=True, download=True, transform=ToTensor())\n",
    "\n",
    "# Held-out split (train=False), loaded from the same root with the same transform.\n",
    "test_ds = datasets.FashionMNIST(root=\"data\", train=False, download=True, transform=ToTensor())"
   ],
   "id": "9168323f2edf16",
   "outputs": [],
   "execution_count": 4
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-12T09:37:04.239931Z",
     "start_time": "2024-08-12T09:37:04.234919Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Training batches of 32 samples, drawn in shuffled order.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_ds,\n",
    "    batch_size=32,\n",
    "    shuffle=True,\n",
    ")\n",
    "# The test split is used as the validation set and is read in a fixed order.\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    test_ds,\n",
    "    batch_size=32,\n",
    "    shuffle=False,\n",
    ")"
   ],
   "id": "8550e6b7d318d8f8",
   "outputs": [],
   "execution_count": 5
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-12T09:37:04.249457Z",
     "start_time": "2024-08-12T09:37:04.240778Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class NeuralNetwork(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.linear_relu_stack = nn.Sequential(\n",
    "            nn.Linear(784, 300),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(300, 100),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(100, 10)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.flatten(x)\n",
    "        logits = self.linear_relu_stack(x)\n",
    "        return logits\n",
    "\n",
    "\n",
    "model = NeuralNetwork()"
   ],
   "id": "932c4ca332c0a10d",
   "outputs": [],
   "execution_count": 6
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-12T09:37:07.126141Z",
     "start_time": "2024-08-12T09:37:04.250472Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from torch.utils.tensorboard import SummaryWriter"
   ],
   "id": "90d8171b8068fb40",
   "outputs": [],
   "execution_count": 7
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-12T09:37:07.130911Z",
     "start_time": "2024-08-12T09:37:07.126141Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class TensorBoardCallback:\n",
    "    def __init__(self, log_dir, flush_secs=10):\n",
    "        self.writer = SummaryWriter(log_dir=log_dir, flush_secs=flush_secs)\n",
    "\n",
    "    def draw_model(self, model, input_shape):\n",
    "        self.writer.add_graph(model, input_to_model=torch.randn(input_shape))\n",
    "\n",
    "    def add_loss_scalars(self, step, loss, val_loss):\n",
    "        self.writer.add_scalar(\n",
    "            tag='train/loss',\n",
    "            scalar_value={'loss': loss, 'val_loss': val_loss},\n",
    "            global_step=step\n",
    "        )\n",
    "\n",
    "    def add_acc_scalars(self, step, acc, val_acc):\n",
    "        self.writer.add_scalar(\n",
    "            tag='train/accuracy',\n",
    "            scalar_value={'accuracy': acc, 'val_acc': val_acc},\n",
    "            global_step=step\n",
    "        )\n",
    "\n",
    "    def add_lr_scalars(self, step, learning_rate):\n",
    "        self.writer.add_scalar(\n",
    "            tag='train/learning_rate',\n",
    "            scalar_value={'learning_rate': learning_rate},\n",
    "            global_step=step\n",
    "        )\n",
    "\n",
    "    def __call__(self, step, **kwargs):\n",
    "        #loss曲线\n",
    "        loss = kwargs.pop('loss', None)\n",
    "        val_loss = kwargs.pop('val_loss', None)\n",
    "        if loss is not None and val_loss is not None:\n",
    "            self.add_loss_scalars(step, loss, val_loss)\n",
    "        #acc曲线\n",
    "        acc = kwargs.pop('acc', None)\n",
    "        val_acc = kwargs.pop('val_acc', None)\n",
    "        if acc is not None and val_acc is not None:\n",
    "            self.add_acc_scalars(step, acc, val_acc)\n",
    "        #lr曲线\n",
    "        learning_rate = kwargs.pop('lr', None)\n",
    "        if learning_rate is not None:\n",
    "            self.add_lr_scalars(step, learning_rate)"
   ],
   "id": "636f9f01db4eedeb",
   "outputs": [],
   "execution_count": 8
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-12T09:37:07.139934Z",
     "start_time": "2024-08-12T09:37:07.130911Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class SaveCheckpointsCallback:\n",
    "    def __init__(self, save_dir, save_step=5000, save_best_only=True):\n",
    "        self.save_dir = save_dir\n",
    "        self.save_step = save_step\n",
    "        self.save_best_only = save_best_only\n",
    "        self.best_metrics = -1\n",
    "\n",
    "        if not os.path.exists(self.save_dir):\n",
    "            os.mkdir(self.save_dir)\n",
    "\n",
    "    def __call__(self, step, state_dict, metric=None):\n",
    "        if step % self.save_step != 0:\n",
    "            return\n",
    "        if self.save_best_only:\n",
    "            assert metric is not None\n",
    "            if metric >= self.best_metrics:\n",
    "                self.best_metrics = metric\n",
    "                torch.save(state_dict, os.path.join(self.save_dir, 'best.ckpt'))\n",
    "        else:\n",
    "            torch.save(state_dict, os.path.join(self.save_dir, f\"{第step步}.ckpt\"))"
   ],
   "id": "e8f9e4788b9bc7f4",
   "outputs": [],
   "execution_count": 9
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-12T09:37:07.148957Z",
     "start_time": "2024-08-12T09:37:07.140921Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class EarlyStopCallback:\n",
    "    def __init__(self, patience=5, min_delta=0.01):\n",
    "        self.patience = patience\n",
    "        self.min_delta = min_delta\n",
    "        self.best_metric = -1\n",
    "        self.counter = 0\n",
    "\n",
    "    def __call__(self, metric):\n",
    "        if metric >= self.best_metric + self.min_delta:\n",
    "            self.best_metric = metric\n",
    "            self.counter = 0\n",
    "        else:\n",
    "            self.counter += 1\n",
    "\n",
    "    @property\n",
    "    def early_stop(self):\n",
    "        return self.counter > self.patience"
   ],
   "id": "b9eaa15955d420a1",
   "outputs": [],
   "execution_count": 10
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-12T09:37:07.156642Z",
     "start_time": "2024-08-12T09:37:07.149465Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def training(model, train_loader, val_loader,\n",
    "             epoch, optmizer, loss_fct,\n",
    "             eval_step=500, tensorbard_callback=None, save_ckpt_callback=None,\n",
    "             early_stop_callback=None):\n",
    "    \"\"\"Train `model` and evaluate it on `val_loader` every `eval_step` steps.\n",
    "\n",
    "    Args:\n",
    "        model: network to optimise; batches are moved to the global `device` before the forward pass.\n",
    "        train_loader: iterable of (inputs, labels) training batches.\n",
    "        val_loader: iterable of (inputs, labels) batches passed to `evaluating`.\n",
    "        epoch: number of passes over `train_loader`.\n",
    "        optmizer: optimizer whose zero_grad()/step() are called once per batch.\n",
    "        loss_fct: callable mapping (logits, labels) to a scalar loss tensor.\n",
    "        eval_step: validation runs whenever the global step is a multiple of this value.\n",
    "        tensorbard_callback: optional callable(step, loss=, val_loss=, acc=, val_acc=, lr=).\n",
    "        save_ckpt_callback: optional callable(step, state_dict, metric=val_acc).\n",
    "        early_stop_callback: optional callable(val_acc). Training stops when the call returns\n",
    "            True or when the callback has a truthy `early_stop` attribute afterwards.\n",
    "\n",
    "    Returns:\n",
    "        dict with keys 'train' and 'val', each a list of\n",
    "        {'loss', 'accuracy', 'step'} records.\n",
    "    \"\"\"\n",
    "    record_dict = {'train': [], 'val': []}\n",
    "    global_step = 0\n",
    "    # switch to training mode\n",
    "    model.train()\n",
    "    with tqdm(total=epoch * len(train_loader)) as pbar:\n",
    "        for epoch_id in range(epoch):\n",
    "            for datas, labels in train_loader:\n",
    "                datas = datas.to(device)\n",
    "                labels = labels.to(device)\n",
    "                optmizer.zero_grad()\n",
    "                logits = model(datas)\n",
    "                loss = loss_fct(logits, labels)\n",
    "                # backward pass and parameter update\n",
    "                loss.backward()\n",
    "                optmizer.step()\n",
    "                # record the metrics of this training batch\n",
    "                preds = logits.argmax(axis=-1)\n",
    "                acc = accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())\n",
    "                loss = loss.cpu().item()\n",
    "                record_dict['train'].append({\n",
    "                    'accuracy': acc,\n",
    "                    'loss': loss,\n",
    "                    'step': global_step\n",
    "                })\n",
    "                # validation\n",
    "                if global_step % eval_step == 0:\n",
    "                    model.eval()\n",
    "                    val_loss, val_acc = evaluating(model, val_loader, loss_fct)\n",
    "                    # Store the validation accuracy here, not the accuracy of the last training batch.\n",
    "                    record_dict['val'].append({\n",
    "                        'loss': val_loss,\n",
    "                        'accuracy': val_acc,\n",
    "                        'step': global_step\n",
    "                    })\n",
    "                    model.train()\n",
    "                    # visualisation\n",
    "                    if tensorbard_callback is not None:\n",
    "                        tensorbard_callback(global_step,\n",
    "                                            loss=loss, val_loss=val_loss, acc=acc, val_acc=val_acc,\n",
    "                                            lr=optmizer.param_groups[0][\"lr\"])\n",
    "                    # checkpointing\n",
    "                    if save_ckpt_callback is not None:\n",
    "                        save_ckpt_callback(global_step,\n",
    "                                           model.state_dict(),\n",
    "                                           metric=val_acc\n",
    "                                           )\n",
    "                    # early stopping\n",
    "                    if early_stop_callback is not None:\n",
    "                        stop_requested = early_stop_callback(val_acc)\n",
    "                        # A callback may report the decision through its return value or only\n",
    "                        # through an `early_stop` attribute, so check both.\n",
    "                        if stop_requested is True or getattr(early_stop_callback, 'early_stop', False):\n",
    "                            print(f\"Early stop at epoch {epoch_id}/global_step{global_step}\")\n",
    "                            return record_dict\n",
    "\n",
    "                global_step += 1\n",
    "                pbar.update(1)\n",
    "                pbar.set_postfix({'epoch': epoch_id})\n",
    "    return record_dict"
   ],
   "id": "8ed4226674232267",
   "outputs": [],
   "execution_count": 11
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-12T09:37:07.167339Z",
     "start_time": "2024-08-12T09:37:07.157646Z"
    }
   },
   "cell_type": "code",
   "source": [
    "@torch.no_grad()\n",
    "def evaluating(model, dataloader, loss_fct):\n",
    "    \"\"\"Run `model` over `dataloader` without gradient tracking.\n",
    "\n",
    "    The model's train/eval mode is not changed here; the caller is responsible for it.\n",
    "\n",
    "    Returns:\n",
    "        (mean of the per-batch losses, accuracy over all samples)\n",
    "    \"\"\"\n",
    "    batch_losses, all_preds, all_labels = [], [], []\n",
    "    for batch_x, batch_y in dataloader:\n",
    "        batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n",
    "        scores = model(batch_x)\n",
    "        batch_loss = loss_fct(scores, batch_y)\n",
    "        predicted = scores.argmax(axis=-1)\n",
    "        batch_losses.append(batch_loss.item())\n",
    "        all_preds += predicted.cpu().numpy().tolist()\n",
    "        all_labels += batch_y.cpu().numpy().tolist()\n",
    "    accuracy = accuracy_score(all_labels, all_preds)\n",
    "    return np.mean(batch_losses), accuracy"
   ],
   "id": "174824854f01613d",
   "outputs": [],
   "execution_count": 12
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-12T09:37:10.948436Z",
     "start_time": "2024-08-12T09:37:07.168343Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Loss and optimizer for the classifier defined above.\n",
    "loss_fct = nn.CrossEntropyLoss()\n",
    "optmizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n",
    "# Callbacks: TensorBoard logs under \"runs\", best checkpoint under \"checkpoints\", early stopping.\n",
    "tensorboard_callback = TensorBoardCallback(\"runs\")\n",
    "# The graph is drawn here, before the model is moved to `device` below.\n",
    "tensorboard_callback.draw_model(model, [1, 28, 28])\n",
    "save_ckpt_callback = SaveCheckpointsCallback(\"checkpoints\", save_best_only=True)\n",
    "early_stop_callback = EarlyStopCallback(patience=10)\n",
    "epoch = 100\n",
    "# start training\n",
    "model = model.to(device)\n",
    "record = training(model, train_loader, val_loader, epoch, optmizer, loss_fct,\n",
    "                  tensorbard_callback=tensorboard_callback,\n",
    "                  save_ckpt_callback=save_ckpt_callback,\n",
    "                  early_stop_callback=early_stop_callback,\n",
    "                  eval_step=1000\n",
    "                  )"
   ],
   "id": "812d470d9cf65150",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "  0%|          | 0/187500 [00:00<?, ?it/s]"
      ],
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "1be77fe3f1cc462d8b1432c89f18accd"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "NotImplementedError",
     "evalue": "Got <class 'dict'>, but numpy array or torch tensor are expected.",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mNotImplementedError\u001B[0m                       Traceback (most recent call last)",
      "Cell \u001B[1;32mIn[13], line 10\u001B[0m\n\u001B[0;32m      8\u001B[0m \u001B[38;5;66;03m#开始训练\u001B[39;00m\n\u001B[0;32m      9\u001B[0m model \u001B[38;5;241m=\u001B[39m model\u001B[38;5;241m.\u001B[39mto(device)\n\u001B[1;32m---> 10\u001B[0m record \u001B[38;5;241m=\u001B[39m \u001B[43mtraining\u001B[49m\u001B[43m(\u001B[49m\u001B[43mmodel\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtrain_loader\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mval_loader\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mepoch\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43moptmizer\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mloss_fct\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m     11\u001B[0m \u001B[43m                  \u001B[49m\u001B[43mtensorbard_callback\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mtensorboard_callback\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m     12\u001B[0m \u001B[43m                  \u001B[49m\u001B[43msave_ckpt_callback\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43msave_ckpt_callback\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m     13\u001B[0m \u001B[43m                  \u001B[49m\u001B[43mearly_stop_callback\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mearly_stop_callback\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m     14\u001B[0m \u001B[43m                  \u001B[49m\u001B[43meval_step\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m1000\u001B[39;49m\n\u001B[0;32m     15\u001B[0m \u001B[43m                  \u001B[49m\u001B[43m)\u001B[49m\n",
      "Cell \u001B[1;32mIn[11], line 41\u001B[0m, in \u001B[0;36mtraining\u001B[1;34m(model, train_loader, val_loader, epoch, optmizer, loss_fct, eval_step, tensorbard_callback, save_ckpt_callback, early_stop_callback)\u001B[0m\n\u001B[0;32m     39\u001B[0m \u001B[38;5;66;03m#可视化\u001B[39;00m\n\u001B[0;32m     40\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m tensorbard_callback \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m---> 41\u001B[0m     \u001B[43mtensorbard_callback\u001B[49m\u001B[43m(\u001B[49m\u001B[43mglobal_step\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m     42\u001B[0m \u001B[43m                        \u001B[49m\u001B[43mloss\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mloss\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mval_loss\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mval_loss\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43macc\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43macc\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mval_acc\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mval_acc\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m     43\u001B[0m \u001B[43m                        \u001B[49m\u001B[43mlr\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43moptmizer\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mparam_groups\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;241;43m0\u001B[39;49m\u001B[43m]\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mlr\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m]\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m     44\u001B[0m \u001B[38;5;66;03m#模型保存\u001B[39;00m\n\u001B[0;32m     45\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m save_ckpt_callback \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n",
      "Cell \u001B[1;32mIn[8], line 34\u001B[0m, in \u001B[0;36mTensorBoardCallback.__call__\u001B[1;34m(self, step, **kwargs)\u001B[0m\n\u001B[0;32m     32\u001B[0m val_loss \u001B[38;5;241m=\u001B[39m kwargs\u001B[38;5;241m.\u001B[39mpop(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mval_loss\u001B[39m\u001B[38;5;124m'\u001B[39m, \u001B[38;5;28;01mNone\u001B[39;00m)\n\u001B[0;32m     33\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m loss \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m \u001B[38;5;129;01mand\u001B[39;00m val_loss \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m---> 34\u001B[0m     \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43madd_loss_scalars\u001B[49m\u001B[43m(\u001B[49m\u001B[43mstep\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mloss\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mval_loss\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m     35\u001B[0m \u001B[38;5;66;03m#acc曲线\u001B[39;00m\n\u001B[0;32m     36\u001B[0m acc \u001B[38;5;241m=\u001B[39m kwargs\u001B[38;5;241m.\u001B[39mpop(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124macc\u001B[39m\u001B[38;5;124m'\u001B[39m, \u001B[38;5;28;01mNone\u001B[39;00m)\n",
      "Cell \u001B[1;32mIn[8], line 9\u001B[0m, in \u001B[0;36mTensorBoardCallback.add_loss_scalars\u001B[1;34m(self, step, loss, val_loss)\u001B[0m\n\u001B[0;32m      8\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21madd_loss_scalars\u001B[39m(\u001B[38;5;28mself\u001B[39m, step, loss, val_loss):\n\u001B[1;32m----> 9\u001B[0m     \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mwriter\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43madd_scalar\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m     10\u001B[0m \u001B[43m        \u001B[49m\u001B[43mtag\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mtrain/loss\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m,\u001B[49m\n\u001B[0;32m     11\u001B[0m \u001B[43m        \u001B[49m\u001B[43mscalar_value\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43m{\u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mloss\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m:\u001B[49m\u001B[43m \u001B[49m\u001B[43mloss\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mval_loss\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m:\u001B[49m\u001B[43m \u001B[49m\u001B[43mval_loss\u001B[49m\u001B[43m}\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m     12\u001B[0m \u001B[43m        \u001B[49m\u001B[43mglobal_step\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mstep\u001B[49m\n\u001B[0;32m     13\u001B[0m \u001B[43m    \u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[1;32mD:\\Users\\JW\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\torch\\utils\\tensorboard\\writer.py:378\u001B[0m, in \u001B[0;36mSummaryWriter.add_scalar\u001B[1;34m(self, tag, scalar_value, global_step, walltime, new_style, double_precision)\u001B[0m\n\u001B[0;32m    351\u001B[0m \u001B[38;5;250m\u001B[39m\u001B[38;5;124;03m\"\"\"Add scalar data to summary.\u001B[39;00m\n\u001B[0;32m    352\u001B[0m \n\u001B[0;32m    353\u001B[0m \u001B[38;5;124;03mArgs:\u001B[39;00m\n\u001B[1;32m   (...)\u001B[0m\n\u001B[0;32m    374\u001B[0m \n\u001B[0;32m    375\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m    376\u001B[0m torch\u001B[38;5;241m.\u001B[39m_C\u001B[38;5;241m.\u001B[39m_log_api_usage_once(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtensorboard.logging.add_scalar\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m--> 378\u001B[0m summary \u001B[38;5;241m=\u001B[39m \u001B[43mscalar\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m    379\u001B[0m \u001B[43m    \u001B[49m\u001B[43mtag\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mscalar_value\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mnew_style\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mnew_style\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdouble_precision\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mdouble_precision\u001B[49m\n\u001B[0;32m    380\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m    381\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_get_file_writer()\u001B[38;5;241m.\u001B[39madd_summary(summary, global_step, walltime)\n",
      "File \u001B[1;32mD:\\Users\\JW\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\torch\\utils\\tensorboard\\summary.py:371\u001B[0m, in \u001B[0;36mscalar\u001B[1;34m(name, tensor, collections, new_style, double_precision)\u001B[0m\n\u001B[0;32m    354\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mscalar\u001B[39m(name, tensor, collections\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, new_style\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mFalse\u001B[39;00m, double_precision\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mFalse\u001B[39;00m):\n\u001B[0;32m    355\u001B[0m \u001B[38;5;250m    \u001B[39m\u001B[38;5;124;03m\"\"\"Output a `Summary` protocol buffer containing a single scalar value.\u001B[39;00m\n\u001B[0;32m    356\u001B[0m \n\u001B[0;32m    357\u001B[0m \u001B[38;5;124;03m    The generated Summary has a Tensor.proto containing the input Tensor.\u001B[39;00m\n\u001B[1;32m   (...)\u001B[0m\n\u001B[0;32m    369\u001B[0m \u001B[38;5;124;03m      ValueError: If tensor has the wrong shape or type.\u001B[39;00m\n\u001B[0;32m    370\u001B[0m \u001B[38;5;124;03m    \"\"\"\u001B[39;00m\n\u001B[1;32m--> 371\u001B[0m     tensor \u001B[38;5;241m=\u001B[39m \u001B[43mmake_np\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtensor\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241m.\u001B[39msqueeze()\n\u001B[0;32m    372\u001B[0m     \u001B[38;5;28;01massert\u001B[39;00m (\n\u001B[0;32m    373\u001B[0m         tensor\u001B[38;5;241m.\u001B[39mndim \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m0\u001B[39m\n\u001B[0;32m    374\u001B[0m     ), \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mTensor should contain one element (0 dimensions). 
Was given size: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mtensor\u001B[38;5;241m.\u001B[39msize\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m and \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mtensor\u001B[38;5;241m.\u001B[39mndim\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m dimensions.\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m    375\u001B[0m     \u001B[38;5;66;03m# python float is double precision in numpy\u001B[39;00m\n",
      "File \u001B[1;32mD:\\Users\\JW\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\torch\\utils\\tensorboard\\_convert_np.py:25\u001B[0m, in \u001B[0;36mmake_np\u001B[1;34m(x)\u001B[0m\n\u001B[0;32m     23\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(x, torch\u001B[38;5;241m.\u001B[39mTensor):\n\u001B[0;32m     24\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m _prepare_pytorch(x)\n\u001B[1;32m---> 25\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mNotImplementedError\u001B[39;00m(\n\u001B[0;32m     26\u001B[0m     \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mGot \u001B[39m\u001B[38;5;132;01m{\u001B[39;00m\u001B[38;5;28mtype\u001B[39m(x)\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m, but numpy array or torch tensor are expected.\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m     27\u001B[0m )\n",
      "\u001B[1;31mNotImplementedError\u001B[0m: Got <class 'dict'>, but numpy array or torch tensor are expected."
     ]
    }
   ],
   "execution_count": 13
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [],
   "id": "556d0e8e067ca182",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
